import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from datasets.frozen_embeddings.embeddings.nlp import embed_text_bert

from experiments.experimental_pipeline import DEFAULT_CACHE, DEFAULT_RESULTS
from datasets.frozen_embeddings.loader import EmbeddingDataset

# --- Setup Paths ---
# Per-dataset results and cache directories used by this script.
DATASET_NAME = EmbeddingDataset.IMDB
results_path = DEFAULT_RESULTS / DATASET_NAME
debug_path = DEFAULT_CACHE / DATASET_NAME / "debug"
plot_path = results_path / "augmentation_analysis"  # subdir for plots

# Create every directory up front; existing directories are left untouched.
for _directory in (results_path, debug_path, plot_path):
    os.makedirs(_directory, exist_ok=True)

# --- Load Transformations CSV ---
# Each row pairs a transformation label ("Transformation") with review text
# ("Review") -- presumably the review after that transformation was applied.
csv_path = results_path / "review_transformations.csv"
df = pd.read_csv(csv_path)

# Validate with an explicit exception: `assert` statements are removed when
# Python runs with -O, which would let a malformed CSV through.
_missing_columns = {"Transformation", "Review"}.difference(df.columns)
if _missing_columns:
    raise ValueError(
        f"{csv_path} is missing required column(s): {sorted(_missing_columns)}"
    )

transformations = df["Transformation"].tolist()
reviews = df["Review"].tolist()

# --- Compute BERT Embeddings ---
print("Computing BERT embeddings...")
embeddings = embed_text_bert(reviews)

# Guard against silent data loss below: `zip` truncates to the shorter input,
# and a dict keeps only the last row for a repeated transformation label.
if embeddings.shape[0] != len(transformations):
    raise ValueError(
        f"embed_text_bert returned {embeddings.shape[0]} embeddings "
        f"for {len(transformations)} reviews"
    )
_label_counts = pd.Series(transformations).value_counts()
_duplicate_labels = _label_counts[_label_counts > 1].index.tolist()
if _duplicate_labels:
    raise ValueError(f"Duplicate transformation labels in CSV: {_duplicate_labels}")

# Add intercept term: append a constant-1 column (same dtype as the embeddings).
embeddings_with_intercept = np.hstack([embeddings, np.ones((embeddings.shape[0], 1), dtype=embeddings.dtype)])
# Map transformation label -> embedding (with intercept); labels are unique here.
embedding_dict = dict(zip(transformations, embeddings_with_intercept))


# --- Original Embedding ---
# Every transformation is compared against the row labelled "Original".
if "Original" not in embedding_dict:
    raise KeyError(
        f'No "Original" row found in {csv_path}; it is required as the reference embedding.'
    )
original_vec = embedding_dict["Original"]
original_norm_sq = np.dot(original_vec, original_vec)  # squared L2 norm

# --- Load Hessian Inverse ---
# NOTE(review): the file name is hard-coded to "imdb" even though DATASET_NAME
# is configurable above -- confirm whether it should follow DATASET_NAME.
hessian_file = results_path / "imdb_hessian_and_inverse.npz"
# Use the NpzFile as a context manager so its file handle is closed; indexing
# it reads the array fully into memory, so `h_inv` remains valid afterwards.
with np.load(hessian_file) as hessian_data:
    h_inv = hessian_data["hessian_inv"]

# The bilinear forms computed later need a square matrix whose size matches the
# embedding dimension (including the intercept column); fail early and clearly.
_dim = original_vec.shape[0]
if h_inv.shape != (_dim, _dim):
    raise ValueError(
        f"hessian_inv in {hessian_file} has shape {h_inv.shape}; "
        f"expected ({_dim}, {_dim}) to match the embeddings (with intercept)."
    )

# --- Compute Metrics ---
# For every transformation embedding u, relative to the original embedding o:
#   dot_with_original       u . o
#   projection_fractions    |proj_o(u)| / |u|   (= |cos(u, o)|)
#   perpendicular_fractions sqrt(1 - parallel**2): share of |u| orthogonal to o
#   hinv_ip_with_original   u^T H^-1 o
#   hinv_ip_self            u^T H^-1 u
projection_fractions = {}
perpendicular_fractions = {}
dot_with_original = {}
hinv_ip_with_original = {}
hinv_ip_self = {}

for name, vec in embedding_dict.items():
    # Inner product with original
    dot = np.dot(vec, original_vec)
    dot_with_original[name] = dot

    # Fraction of vec's length that lies along original_vec.
    proj_coeff = dot / original_norm_sq
    proj_vec = proj_coeff * original_vec
    proj_mag = np.linalg.norm(proj_vec)
    vec_mag = np.linalg.norm(vec)
    # Cauchy-Schwarz bounds this ratio by 1, but floating-point rounding can
    # push a near-parallel vector slightly above 1, which would make
    # sqrt(1 - parallel_frac**2) NaN. Clamp to keep the result well-defined.
    parallel_frac = np.minimum(proj_mag / vec_mag, 1.0)
    perpendicular_frac = np.sqrt(1.0 - parallel_frac**2)

    projection_fractions[name] = parallel_frac
    perpendicular_fractions[name] = perpendicular_frac

    # Hessian-inverse inner products (vec is 1-D, so no transpose is needed).
    hinv_ip_with_original[name] = vec @ h_inv @ original_vec
    hinv_ip_self[name] = vec @ h_inv @ vec

# --- Plotting Function ---
def plot_bar(data_dict, title, ylabel, filename, out_dir=None):
    """Save a bar chart with one bar per entry of ``data_dict``.

    Bars are ordered by sorting the dict items (i.e. by key). The figure is
    always closed afterwards, even if drawing or saving raises, so repeated
    calls do not accumulate open figures.

    Args:
        data_dict: Mapping of bar label -> numeric value.
        title: Figure title.
        ylabel: Y-axis label.
        filename: Output file name, joined onto ``out_dir``.
        out_dir: Directory to write into. Defaults to the module-level
            ``plot_path`` (looked up at call time).

    Returns:
        None. An empty ``data_dict`` is skipped with a printed notice instead
        of failing on tuple unpacking.
    """
    if not data_dict:
        print(f"Skipping {filename}: no data to plot.")
        return
    if out_dir is None:
        out_dir = plot_path

    labels, values = zip(*sorted(data_dict.items()))
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.bar(labels, values)
        plt.xticks(rotation=45, ha="right")
        plt.title(title)
        plt.ylabel(ylabel)
        plt.tight_layout()
        fig.savefig(os.path.join(out_dir, filename))
    finally:
        # Close this specific figure even on error to avoid leaking figures.
        plt.close(fig)

# --- Save Plots ---
# One entry per chart: (metric dict, title, y-axis label, output file name).
_plot_specs = (
    (dot_with_original,
     "Dot Product with Original Embedding",
     "Dot Product",
     "dot_with_original.png"),
    (perpendicular_fractions,
     "Perpendicular Component Fraction",
     "Fraction",
     "perpendicular_fraction.png"),
    (hinv_ip_with_original,
     "H^-1 Inner Product with Original",
     "uᵀH⁻¹v",
     "h_inv_ip_with_original.png"),
    (hinv_ip_self,
     "H^-1 Inner Product with Self",
     "uᵀH⁻¹u",
     "h_inv_ip_self.png"),
)
for _metric, _title, _ylabel, _filename in _plot_specs:
    plot_bar(_metric, _title, _ylabel, _filename)

print(f"Plots saved to: {plot_path}")
